from __future__ import annotations
import numpy as np
from dataclasses import dataclass
from typing import Literal, Tuple

@dataclass
class FitCfg:
    """Settings for the radial power-law fit.

    The field order is part of the positional constructor signature.
    """

    # --- fit window ---------------------------------------------------------
    # Inner radius: either a plain number or a "<k>*sigma" string that is
    # scaled by the sigma value supplied at fit time.
    fit_window_min: "str | float" = "6*sigma"
    # Outer radius, expressed as a fraction of the map side length L.
    fit_window_max_fracL: float = 0.25

    # --- radial binning -----------------------------------------------------
    # Spacing of the bin edges.
    radial_bins_scheme: Literal["log", "linear"] = "log"
    # Number of radial bins.
    radial_bins: int = 36

    # --- regression ---------------------------------------------------------
    # "counts" weights each bin by its pixel count; "uniform" gives every
    # usable bin the same weight.
    regression_weights: Literal["counts", "uniform"] = "counts"

def _parse_min_window(val, sigma: float) -> float:
    """Resolve the inner fit radius from a number or a "<k>*sigma" string.

    Accepted forms of ``val``:
    - a real number (Python ``int``/``float`` or a NumPy integer/floating
      scalar), returned as a float unchanged;
    - a string "<k>*sigma" (case and spaces ignored), returned as
      ``k * sigma``;
    - a plain numeric string such as "4.5", returned as that number.

    Anything else, including malformed strings, falls back to ``6.0 * sigma``.
    """
    # NumPy scalars such as np.float32 / np.int64 do not subclass the Python
    # builtins, so they are listed explicitly; otherwise a numeric value would
    # silently be replaced by the 6*sigma default.
    if isinstance(val, (int, float, np.integer, np.floating)):
        return float(val)

    if isinstance(val, str):
        text = val.strip().lower().replace(" ", "")
        suffix = "*sigma"
        if text.endswith(suffix):
            try:
                return float(text[: -len(suffix)]) * float(sigma)
            except ValueError:
                # Missing or malformed coefficient: try the plain-number form.
                pass
        try:
            return float(text)
        except ValueError:
            # Not a number either: use the default below.
            pass

    # Default fallback for unrecognised input.
    return 6.0 * float(sigma)

def _bin_edges(L: int, rmin: float, rmax: float, nbins: int, scheme: str) -> np.ndarray:
    """Return ``nbins + 1`` radial bin edges for the clamped fit window.

    The inner edge is raised to at least 1.0, the outer edge is capped at
    ``0.5 * L - 1.0``, and a window narrower than one unit is widened to
    exactly one unit above the inner edge. ``scheme == "linear"`` yields
    evenly spaced edges; any other value yields logarithmically spaced edges.
    """
    inner = max(1.0, float(rmin))
    outer = min(float(rmax), 0.5 * float(L) - 1.0)
    # Degenerate or inverted window: force a unit-width window above `inner`.
    outer = inner + 1.0 if outer <= inner + 1.0 else outer

    n_edges = nbins + 1
    if scheme == "linear":
        return np.linspace(inner, outer, n_edges)

    # Logarithmic spacing. `inner` is already >= 1.0, so its log10 is >= 0.
    log_inner = np.log10(inner)
    log_outer = np.log10(max(inner + 1e-9, outer))
    return np.logspace(log_inner, log_outer, n_edges)

def _radial_stats(
    arr: np.ndarray,
    cy: float,
    cx: float,
    edges: np.ndarray,
    use_abs: bool,
    min_count: int = 10,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Average a 2D array in annuli around the point (cy, cx).

    Bin ``k`` collects the pixels whose distance ``r`` from the centre
    satisfies ``edges[k] <= r < edges[k + 1]``.

    Returns ``(centers, means, counts)``:
    - ``centers``: arithmetic midpoint of each pair of consecutive edges;
    - ``means``: NaN-ignoring mean of the pixel values in each bin (of their
      absolute values when ``use_abs`` is true), or NaN when the bin holds
      fewer than ``min_count`` pixels;
    - ``counts``: number of pixels in each bin, stored as floats.
    """
    n_rows, n_cols = arr.shape
    # Open grids broadcast to the full (n_rows, n_cols) distance map.
    rows, cols = np.ogrid[:n_rows, :n_cols]
    radius = np.hypot(cols - cx, rows - cy)

    n_bins = len(edges) - 1
    means = np.full(n_bins, np.nan, dtype=float)  # NaN marks under-filled bins
    counts = np.zeros(n_bins, dtype=float)
    samples = np.abs(arr) if use_abs else arr

    for k, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        in_bin = (radius >= lo) & (radius < hi)
        n_pix = int(np.count_nonzero(in_bin))
        counts[k] = n_pix
        if n_pix >= min_count:
            means[k] = float(np.nanmean(samples[in_bin]))

    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, means, counts

def _weighted_slope(x: np.ndarray, y: np.ndarray, w: np.ndarray) -> Tuple[float, float]:
    """Weighted least-squares fit of y = a + b x; return (b, weighted R^2).

    Points with a non-finite coordinate, or a non-finite or non-positive
    weight, are dropped first. The result is ``(nan, nan)`` when fewer than
    three points remain or when x has no weighted spread. R^2 alone is nan
    when y has no weighted spread.
    """
    usable = np.isfinite(x) & np.isfinite(y) & np.isfinite(w) & (w > 0)
    xs, ys, ws = x[usable], y[usable], w[usable]
    if xs.size < 3:
        return np.nan, np.nan

    # Weights normalised to unit sum, so every sum below is a weighted mean.
    ws = ws / ws.sum()
    x_mean = (ws * xs).sum()
    y_mean = (ws * ys).sum()
    dx = xs - x_mean
    dy = ys - y_mean

    var_x = (ws * dx**2).sum()
    if var_x <= 0:
        return np.nan, np.nan
    cov_xy = (ws * dx * dy).sum()

    slope = cov_xy / var_x
    intercept = y_mean - slope * x_mean

    # Weighted R^2 = 1 - SSE / SST.
    residual = ys - (intercept + slope * xs)
    sse = (ws * residual**2).sum()
    sst = (ws * dy**2).sum()
    if sst > 0:
        r2 = 1.0 - (sse / sst)
    else:
        r2 = np.nan
    return float(slope), float(r2)

def _log_profile_with_weights(prof: np.ndarray, cnt: np.ndarray, uniform: bool) -> Tuple[np.ndarray, np.ndarray]:
    """Return (log(prof), weights), disabling bins that cannot be log-transformed."""
    with np.errstate(invalid="ignore"):
        usable = np.isfinite(prof) & (prof > 0)
    # Unusable bins stay NaN with zero weight, so the regression skips them
    # instead of seeing an artificial log(epsilon) outlier.
    log_prof = np.full(prof.shape, np.nan, dtype=float)
    log_prof[usable] = np.log(prof[usable])
    weights = usable.astype(float) if uniform else np.where(usable, cnt, 0.0)
    return log_prof, weights


def fit_slopes(phi_map: np.ndarray, grad_map: np.ndarray, sigma_max: float, cfg: FitCfg) -> dict:
    """
    Compute power-law slopes of |phi| and |∇phi| vs radius.
    - Center is chosen as the argmax of |phi| to mimic S+ centroid without requiring the mask.
    - Radial profiles use mean(|phi|) and mean(|∇phi|) per bin to ensure positivity for log-space fits.
    - Bins whose mean is zero, negative or non-finite are left out of the regression.
    - The fit window runs from the parsed ``cfg.fit_window_min`` (scaled by ``sigma_max`` for
      "<k>*sigma" strings) to ``cfg.fit_window_max_fracL`` times the shorter map side.

    Returns a dict with keys ``s_phi``, ``r2_phi``, ``s_grad``, ``r2_grad``: the log-log slope and
    weighted R^2 of each profile. A value is NaN when fewer than three bins are usable.
    """
    H, W = phi_map.shape
    # Shorter side, so the radial window also stays inside non-square maps
    # (identical to the previous behaviour for square maps).
    L = min(H, W)
    # choose center at |phi| maximum
    idx = int(np.nanargmax(np.abs(phi_map)))
    cy, cx = divmod(idx, W)
    # fit window
    rmin = _parse_min_window(cfg.fit_window_min, float(sigma_max))
    rmax = cfg.fit_window_max_fracL * float(L)
    edges = _bin_edges(L, rmin, rmax, int(cfg.radial_bins), cfg.radial_bins_scheme)
    # radial profiles; both use absolute values, as the docstring promises,
    # so signed gradients cannot cancel or produce a negative bin mean
    rmid, phi_prof, phi_cnt = _radial_stats(phi_map, cy, cx, edges, use_abs=True, min_count=10)
    _, grad_prof, grad_cnt = _radial_stats(grad_map, cy, cx, edges, use_abs=True, min_count=10)
    # log space, with weights zeroed for bins that have no valid logarithm
    uniform = cfg.regression_weights == "uniform"
    x = np.log(rmid + 1e-12)
    y_phi, w_phi = _log_profile_with_weights(phi_prof, phi_cnt, uniform)
    y_grad, w_grad = _log_profile_with_weights(grad_prof, grad_cnt, uniform)
    s_phi, r2_phi = _weighted_slope(x, y_phi, w_phi)
    s_grad, r2_grad = _weighted_slope(x, y_grad, w_grad)
    return dict(s_phi=s_phi, r2_phi=r2_phi, s_grad=s_grad, r2_grad=r2_grad)
